62a51b
@@ -1752,7 +1752,7 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
         // these ndvs are later used to compute unmatched rows and num of nulls for outer joins
         List<Long> ndvsUnmatched= Lists.newArrayList();
         long denom = 1;
-        long denomUnmatched = 1;
+        long distinctUnmatched = 1;
         if (inferredRowCount == -1) {
           // failed to infer PK-FK relationship for row count estimation fall-back on default logic
           // compute denominator  max(V(R,y1), V(S,y1)) * max(V(R,y2), V(S,y2))
@@ -1774,12 +1774,12 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
 
           if (numAttr > 1 && conf.getBoolVar(HiveConf.ConfVars.HIVE_STATS_CORRELATED_MULTI_KEY_JOINS)) {
             denom = Collections.max(distinctVals);
-            denomUnmatched = denom - ndvsUnmatched.get(distinctVals.indexOf(denom));
+            distinctUnmatched = denom - ndvsUnmatched.get(distinctVals.indexOf(denom));
           } else {
             // To avoid denominator getting larger and aggressively reducing
             // number of rows, we will ease out denominator.
             denom = StatsUtils.addWithExpDecay(distinctVals);
-            denomUnmatched = denom - StatsUtils.addWithExpDecay(ndvsUnmatched);
+            distinctUnmatched = denom - StatsUtils.addWithExpDecay(ndvsUnmatched);
           }
         }
 
@@ -1810,22 +1810,32 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
         // update join statistics
         stats.setColumnStats(outColStats);
 
-        long interimRowCount = inferredRowCount != -1 ? inferredRowCount
-          : computeRowCountAssumingInnerJoin(rowCounts, denom, jop);
-        // final row computation will consider join type
-        long joinRowCount = inferredRowCount != -1 ? inferredRowCount
-          : computeFinalRowCount(rowCounts, interimRowCount, jop);
-
-        // the idea is to measure unmatche rows in outer joins by figuring out how many rows didn't match
-        // mismatched rows are figured using denomUnmatched which is the difference of denom used for computing
-        // join cardinality minus the ndv which wasn't used. This number (mismatched rows) is then subtracted from
-        /// join cardinality to get the rows which didn't match
-        long unMatchedRows = Math.abs(computeRowCountAssumingInnerJoin(rowCounts, denomUnmatched, jop) - joinRowCount);
-        if(denomUnmatched == 0) {
-          // if unmatched denominator is zero we take it as all rows will match
-          unMatchedRows = 0;
+        long joinRowCount;
+        long leftUnmatchedRows = 0L;
+        long rightUnmatchedRows = 0L;
+        if (inferredRowCount != -1) {
+          joinRowCount = inferredRowCount;
+        } else {
+          long innerJoinRowCount = computeRowCountAssumingInnerJoin(rowCounts, denom, jop);
+          // the idea is to measure unmatched rows in outer joins by figuring out how many rows didn't match
+          if (jop.getConf().getConds().length == 1) {
+            // TODO: Consider more than one condition
+            JoinCondDesc joinCond = jop.getConf().getConds()[0];
+            if (joinCond.getType() == JoinDesc.LEFT_OUTER_JOIN) {
+              leftUnmatchedRows = calculateUnmatchedRowsForOuter(conf, rowCountParents.get(0), joinKeys.get(0), joinStats.get(0), distinctUnmatched);
+            } else if (joinCond.getType() == JoinDesc.RIGHT_OUTER_JOIN) {
+              rightUnmatchedRows = calculateUnmatchedRowsForOuter(conf, rowCountParents.get(1), joinKeys.get(1), joinStats.get(1), distinctUnmatched);
+            } else if (joinCond.getType() == JoinDesc.FULL_OUTER_JOIN) {
+              leftUnmatchedRows = calculateUnmatchedRowsForOuter(conf, rowCountParents.get(0), joinKeys.get(0), joinStats.get(0), distinctUnmatched);
+              rightUnmatchedRows = calculateUnmatchedRowsForOuter(conf, rowCountParents.get(1), joinKeys.get(1), joinStats.get(1), distinctUnmatched);
+            }
+          }
+          // final row computation will consider join type
+          joinRowCount = computeFinalRowCount(rowCounts, StatsUtils.safeAdd(innerJoinRowCount, StatsUtils.safeAdd(leftUnmatchedRows, rightUnmatchedRows)), jop);
         }
-        updateColStats(conf, stats, unMatchedRows, joinRowCount, jop, rowCountParents);
+
+        // update column statistics
+        updateColStats(conf, stats, leftUnmatchedRows, rightUnmatchedRows, joinRowCount, jop, rowCountParents);
 
         // evaluate filter expression and update statistics
         if (joinRowCount != -1 && jop.getConf().getNoOuterJoin() &&
@@ -1951,6 +1961,38 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
       return null;
     }
 
+    private long calculateUnmatchedRowsForOuter(HiveConf conf, long inputRowCount,
+        List<String> joinKeys, Statistics statistics, long distinctUnmatched) {
+      // Estimate the rows of one outer-join input that find no match; first collect key ndvs
+      List<Long> distinctVals = new ArrayList<>();
+      for (String col: joinKeys) {
+        ColStatistics cs = statistics.getColumnStatisticsFromColName(col);
+        if (cs != null) {
+          distinctVals.add(cs.getCountDistint());
+        }
+      }
+      // Combine the ndvs based on configuration property (falls back to 2 without column stats)
+      long distinctVal;
+      if (distinctVals.isEmpty()) {
+        distinctVal = 2L;
+      } else if (joinKeys.size() > 1 && conf.getBoolVar(HiveConf.ConfVars.HIVE_STATS_CORRELATED_MULTI_KEY_JOINS)) {
+        distinctVal = Collections.max(distinctVals);
+      } else {
+        distinctVal = StatsUtils.addWithExpDecay(distinctVals);
+      }
+      // No (or a negative) unmatched ndv estimate: assume every row finds a match
+      if (distinctUnmatched <= 0) {
+        return 0L;
+      }
+      // Unmatched ndv covers all distinct values (also guards distinctVal == 0): assume no matches
+      if (distinctUnmatched >= distinctVal) {
+        return inputRowCount;
+      }
+      // Scale rows by distinctUnmatched / distinctVal, multiplying before truncating so that
+      // inputRowCount < distinctVal does not round to zero; the result is at most inputRowCount
+      return (long) ((double) inputRowCount * distinctUnmatched / distinctVal);
+    }
+
     private long inferPKFKRelationship(int numAttr, List<Operator<? extends OperatorDesc>> parents,
         CommonJoinOperator<? extends JoinDesc> jop) {
       long newNumRows = -1;
@@ -2222,8 +2264,8 @@
private boolean isJoinKey(final String columnName,
       return false;
     }
 
-    private void updateNumNulls(ColStatistics colStats, long unmatchedRows, long newNumRows,
-        long pos, CommonJoinOperator<? extends JoinDesc> jop) {
+    private void updateNumNulls(ColStatistics colStats, long leftUnmatchedRows, long rightUnmatchedRows,
+        long newNumRows, long pos, CommonJoinOperator<? extends JoinDesc> jop) {
 
       if (!(jop.getConf().getConds().length == 1)) {
         // TODO: handle multi joins
@@ -2236,33 +2278,28 @@
private void updateNumNulls(ColStatistics colStats, long unmatchedRows, long new
       JoinCondDesc joinCond = jop.getConf().getConds()[0];
       switch (joinCond.getType()) {
       case JoinDesc.LEFT_OUTER_JOIN:
-        //if this column is coming from right input only then we update num nulls
-        if (pos == joinCond.getRight()
-            && unmatchedRows != newNumRows) {
+        if (pos == joinCond.getRight()) {
           if (isJoinKey(colStats.getColumnName(), jop.getConf().getJoinKeys())) {
-            newNumNulls = Math.min(newNumRows, (unmatchedRows));
+            newNumNulls = Math.min(newNumRows, leftUnmatchedRows);
           } else {
-            newNumNulls = Math.min(newNumRows, oldNumNulls + (unmatchedRows));
+            newNumNulls = Math.min(newNumRows, oldNumNulls + leftUnmatchedRows);
           }
         }
         break;
       case JoinDesc.RIGHT_OUTER_JOIN:
-        if (pos == joinCond.getLeft()
-            && unmatchedRows != newNumRows) {
-
+        if (pos == joinCond.getLeft()) {
           if (isJoinKey(colStats.getColumnName(), jop.getConf().getJoinKeys())) {
-            newNumNulls = Math.min(newNumRows, ( unmatchedRows));
+            newNumNulls = Math.min(newNumRows, rightUnmatchedRows);
           } else {
-            // TODO: oldNumNulls should be scaled instead of taken as it is
-            newNumNulls = Math.min(newNumRows, oldNumNulls + (unmatchedRows));
+            newNumNulls = Math.min(newNumRows, oldNumNulls + rightUnmatchedRows);
           }
         }
         break;
       case JoinDesc.FULL_OUTER_JOIN:
         if (isJoinKey(colStats.getColumnName(), jop.getConf().getJoinKeys())) {
-          newNumNulls = Math.min(newNumRows, (unmatchedRows));
+          newNumNulls = Math.min(newNumRows, leftUnmatchedRows + rightUnmatchedRows);
         } else {
-          newNumNulls = Math.min(newNumRows, oldNumNulls + (unmatchedRows));
+          newNumNulls = Math.min(newNumRows, oldNumNulls + leftUnmatchedRows + rightUnmatchedRows);
         }
         break;
 
@@ -2274,10 +2311,8 @@
private void updateNumNulls(ColStatistics colStats, long unmatchedRows, long new
       colStats.setNumNulls(newNumNulls);
     }
 
-    private void updateColStats(HiveConf conf, Statistics stats, long interimNumRows,
-        long newNumRows,
-        CommonJoinOperator<? extends JoinDesc> jop,
-        Map<Integer, Long> rowCountParents) {
+    private void updateColStats(HiveConf conf, Statistics stats, long leftUnmatchedRows, long rightUnmatchedRows,
+        long newNumRows, CommonJoinOperator<? extends JoinDesc> jop, Map<Integer, Long> rowCountParents) {
 
       if (newNumRows < 0) {
         LOG.debug("STATS-" + jop.toString() + ": Overflow in number of rows. "
@@ -2316,7 +2351,7 @@
private void updateColStats(HiveConf conf, Statistics stats, long interimNumRows
         }
 
         cs.setCountDistint(newDV);
-        updateNumNulls(cs, interimNumRows, newNumRows, pos, jop);
+        updateNumNulls(cs, leftUnmatchedRows, rightUnmatchedRows, newNumRows, pos, jop);
       }
       stats.setColumnStats(colStats);
       long newDataSize = StatsUtils
@@ -2467,6 +2502,7 @@
private long getDenominatorForUnmatchedRows(List<Long> distinctVals) {
         return denom;
       }
     }
+
     private long getDenominator(List<Long> distinctVals) {
 
       if (distinctVals.isEmpty()) {
